syntax = "proto3";

package tensorflow.data;

import "tensorflow/core/data/service/common.proto";
import "tensorflow/core/framework/dataset.proto";

// Request message for WorkerService.ProcessTask.
message ProcessTaskRequest {
  // Definition of the task the worker is asked to process. TaskDef is not
  // declared in this file (presumably it comes from the imported common.proto
  // -- confirm).
  TaskDef task = 1;
}

// Response message for WorkerService.ProcessTask. It currently has no fields;
// keeping a dedicated message lets fields be added later without changing the
// RPC signature.
message ProcessTaskResponse {}

// Request message for WorkerService.GetElement.
message GetElementRequest {
  // The task to fetch an element from.
  int64 task_id = 1;
  // Optional index to identify the consumer. Declared as a single-member oneof
  // so that presence is tracked explicitly: an unset index can be told apart
  // from an index of 0.
  oneof optional_consumer_index {
    int64 consumer_index = 2;
  }
  // Optional round index, indicating which round of round-robin the consumer
  // wants to read from. This is used to keep consumers in sync. Declared as a
  // single-member oneof for the same explicit-presence reason as
  // `optional_consumer_index`.
  oneof optional_round_index {
    int64 round_index = 3;
  }
  // Whether the previous round was skipped. This information is needed by the
  // worker to recover after restarts.
  bool skipped_previous_round = 4;
  // Whether to skip the round if data isn't ready fast enough.
  bool allow_skip = 5;
  // The trainer ID used to read elements from a multi-trainer cache. This cache
  // enables sharing data across concurrent training iterations. If set, this
  // request will read the data requested by other trainers, if available.
  // As a plain proto3 string (implicit presence), an empty value is
  // indistinguishable from "not set".
  string trainer_id = 6;
}

// Response message for WorkerService.GetElement.
//
// NOTE(review): field number 1 is not used by any field here and is not
// declared `reserved`. If it belonged to a removed field, consider adding
// `reserved 1;` so the number cannot be reused by accident -- confirm against
// the file history.
message GetElementResponse {
  // The produced element. At most one of the two representations is set.
  oneof element {
    CompressedElement compressed = 3;
    UncompressedElement uncompressed = 5;
  }
  // The element's index within the task it came from.
  int64 element_index = 6;
  // Boolean to indicate whether the iterator has been exhausted.
  bool end_of_sequence = 2;
  // Indicates whether the round was skipped.
  bool skip_task = 4;
}

// Request message for WorkerService.GetWorkerTasks. It currently has no
// fields.
//
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto.
message GetWorkerTasksRequest {}

// Response message for WorkerService.GetWorkerTasks.
message GetWorkerTasksResponse {
  // The tasks reported by the worker; per the GetWorkerTasks RPC, these are the
  // tasks it is currently executing. An empty list is indistinguishable from an
  // unset field.
  repeated TaskInfo tasks = 1;
}

// Request message for WorkerService.GetSnapshotTaskProgresses. It currently
// has no fields.
message GetSnapshotTaskProgressesRequest {}

// Response message for WorkerService.GetSnapshotTaskProgresses.
message GetSnapshotTaskProgressesResponse {
  // One progress entry per snapshot task; per the GetSnapshotTaskProgresses
  // RPC, these are the snapshot tasks the worker is currently executing.
  // SnapshotTaskProgress is not declared in this file.
  repeated SnapshotTaskProgress snapshot_task_progresses = 1;
}

// RPCs handled by a worker. All four methods are unary: each takes a single
// request message and returns a single response message.
service WorkerService {
  // Processes a task for a dataset, making elements available to clients.
  // The task to process is carried in ProcessTaskRequest.task.
  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

  // Gets the next dataset element. The task to read from is identified by
  // GetElementRequest.task_id.
  rpc GetElement(GetElementRequest) returns (GetElementResponse);

  // Gets the tasks currently being executed by the worker.
  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);

  // Gets the progresses of the snapshot tasks currently being executed by the
  // worker.
  rpc GetSnapshotTaskProgresses(GetSnapshotTaskProgressesRequest)
      returns (GetSnapshotTaskProgressesResponse);
}
